{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify`\n",
    "\n",
    "参考：`tvm/tests/python/arith/test_arith_canonical_simplify.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ..\n",
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import te\n",
    "\n",
    "class CanonicalChecker:\n",
    "    def __init__(self):\n",
    "        self.analyzer = tvm.arith.Analyzer()\n",
    "\n",
    "    def verify(self, data, expected):\n",
    "        res = self.analyzer.canonical_simplify(data)\n",
    "        expected = tvm.runtime.convert(expected)\n",
    "        assert tvm.ir.structural_equal(res, expected), \"\\ndata={}\\nres={}\\nexpected={}\".format(\n",
    "            data, res, expected\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` mul+sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x, y, z = te.var(\"x\"), te.var(\"y\"), te.var(\"z\")\n",
    "\n",
    "ck.verify(2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6)\n",
    "ck.verify(x * 3 - 4 * x + 1, 1 - x)\n",
    "ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)\n",
    "tdiv = tvm.tir.truncdiv\n",
    "tmod = tvm.tir.truncmod\n",
    "# trucdiv\n",
    "ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x)\n",
    "ck.verify(tmod(x + y + x + y * 3, 2), 0)\n",
    "\n",
    "# floordiv\n",
    "fld = tvm.te.floordiv\n",
    "flm = tvm.te.floormod\n",
    "ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))\n",
    "ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)\n",
    "ck.verify(flm(x + y + x + y * 3, 2), 0)\n",
    "ck.verify(fld(x + x + y * 3, 2), fld(y * 3, 2) + x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` plit_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x, y, z = te.var(\"x\"), te.var(\"y\"), te.var(\"z\")\n",
    "\n",
    "# trucdiv\n",
    "tdiv = tvm.tir.truncdiv\n",
    "tmod = tvm.tir.truncmod\n",
    "\n",
    "# split div const\n",
    "ck.verify(tdiv(x, 3) * 3 + tmod(x, 3), x)\n",
    "ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)\n",
    "ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))\n",
    "ck.verify(tdiv(tmod(x, 2), 8), 0)\n",
    "ck.verify(tdiv(tmod(x, 2), 7), 0)\n",
    "ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))\n",
    "\n",
    "# split mod const\n",
    "ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)\n",
    "ck.verify(tmod(x * 8, 2), 0)\n",
    "\n",
    "# simplify then fold\n",
    "ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))\n",
    "ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))\n",
    "ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)\n",
    "# complex fold\n",
    "ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)\n",
    "\n",
    "ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)\n",
    "ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)\n",
    "ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)\n",
    "\n",
    "# floordiv\n",
    "fld = tvm.te.floordiv\n",
    "flm = tvm.te.floormod\n",
    "ck.verify(fld(x * 5, 2), fld(x * 5, 2))\n",
    "ck.verify(fld(x, 3) * 3 + flm(x, 3), x)\n",
    "ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)\n",
    "ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))\n",
    "ck.verify(fld(flm(x, 2), 8), 0)\n",
    "ck.verify(fld(flm(x, 2), 7), 0)\n",
    "ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))\n",
    "\n",
    "# cannot simplify mixed case, unless we canonicalize into one mode.\n",
    "ck.verify(tdiv(x, 6) * 2 + tmod(fld(x, 3), 2), tdiv(x, 6) * 2 + tmod(fld(x, 3), 2))\n",
    "\n",
    "ck.verify(tmod(-x, 2), tmod(x, -2) * -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` div"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x = te.var(\"x\")\n",
    "tdiv = tvm.tir.truncdiv\n",
    "\n",
    "# truc div\n",
    "ck.verify(tdiv(16 + 48 * x, 16), x * 3 + 1)\n",
    "# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0\n",
    "# (17+48*x)/16 != 1+3*x\n",
    "ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))\n",
    "# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified\n",
    "ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))\n",
    "ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)\n",
    "# Trying expressions that are not simplifiable for any values of the variables\n",
    "ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))\n",
    "\n",
    "# floordiv\n",
    "fld = tvm.te.floordiv\n",
    "ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)\n",
    "ck.verify(fld(16 + 48 * x, 16), x * 3 + 1)\n",
    "ck.verify(fld(17 + 48 * x, 16), x * 3 + 1)\n",
    "ck.verify(fld(17 + 47 * x, 16), fld(x * 47 + 17, 16))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` fp16_const_fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "zero = tvm.tir.const(0, \"float16\")\n",
    "one = tvm.tir.const(1, \"float16\")\n",
    "half = tvm.tir.const(0.5, \"float16\")\n",
    "\n",
    "ck.verify(zero + half, half)\n",
    "ck.verify(half - zero, half)\n",
    "\n",
    "ck.verify(zero * half, zero)\n",
    "ck.verify(half * one, half)\n",
    "\n",
    "ck.verify(half / one, half)\n",
    "ck.verify(zero / half, zero)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`tvm.arith.Analyzer.canonical_simplify` floormod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "flm = tvm.te.floormod\n",
    "x, y = te.var(\"x\"), te.var(\"y\")\n",
    "ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16))\n",
    "ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4)\n",
    "\n",
    "ck.verify(flm(-x, 2), flm(x, -2) * -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` canonical_mixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x = te.var(\"x\")\n",
    "z = tvm.tir.const(3, \"int32\")\n",
    "tdiv = tvm.tir.truncdiv\n",
    "tmod = tvm.tir.truncmod\n",
    "ck.verify(tdiv(x, (z * z)) - tdiv(x, (z * z)), 0)\n",
    "ck.verify(tdiv(x, (z + z)) - tdiv(x, (z + z)), 0)\n",
    "ck.verify(x - 2 < 3, x < 5)\n",
    "ck.verify(tvm.te.max(x, 1) - tvm.te.max(x, 1), 0)\n",
    "ck.verify(tvm.te.min(x, 1) - tvm.te.min(x, 1), 0)\n",
    "ck.verify(x * x - x * x, 0)\n",
    "ck.verify(tmod(tdiv(tmod(x, 20), 2) * 2, 4), tdiv(tmod(x, 4), 2) * 2)\n",
    "\n",
    "fld = tvm.te.floordiv\n",
    "ck.verify(fld(x, (z * z)) - fld(x, (z * z)), 0)\n",
    "ck.verify(fld(x, (z + z)) - fld(x, (z + z)), 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` reduce_combiner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "dummy = te.var(\"dummy\")\n",
    "comm_reducer = te.comm_reducer\n",
    "prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))\n",
    "\n",
    "sum_or_prod = comm_reducer(\n",
    "    lambda x, y: tvm.tir.Select(dummy < 0, x + y, x * y),\n",
    "    lambda t0: tvm.tir.Select(dummy < 0, tvm.tir.const(0, t0), tvm.tir.const(1, t0)),\n",
    ")\n",
    "sum_and_prod = comm_reducer(\n",
    "    lambda x, y: (x[0] + y[0], x[1] * y[1]),\n",
    "    lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)),\n",
    ")\n",
    "some_reducer1 = comm_reducer(\n",
    "    lambda x, y: (\n",
    "        x[0] + y[0],\n",
    "        x[0] + y[0] + x[1] + y[1],\n",
    "        x[0] * y[2] + y[0] * x[2],\n",
    "        x[1] + y[2],\n",
    "        4.0,\n",
    "    ),\n",
    "    lambda t0, t1, t2, t3, t4: (\n",
    "        tvm.tir.const(0, t0),\n",
    "        tvm.tir.const(1, t1),\n",
    "        tvm.tir.const(2, t2),\n",
    "        tvm.tir.const(3, t3),\n",
    "        tvm.tir.const(4, t4),\n",
    "    ),\n",
    ")\n",
    "\n",
    "k = te.reduce_axis((0, 10), name=\"k\")\n",
    "A = te.placeholder((10,), name=\"A\")\n",
    "# Test that SimplifyCombiner makes use of vranges\n",
    "ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))\n",
    "ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))\n",
    "ck.verify(sum_or_prod(A[k], k, init=1), te.sum(A[k], k, init=1))\n",
    "ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)\n",
    "ck.verify(sum_or_prod(A[k], k), prod(A[k], k))\n",
    "ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))\n",
    "ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)\n",
    "ck.verify(sum_and_prod((A[k], A[10 - k]), k)[0], te.sum(A[k], k))\n",
    "ck.verify(sum_and_prod((A[k], A[10 - k]), k)[1], prod(A[10 - k], k))\n",
    "\n",
    "reference_simplified_sources = [\n",
    "    [A[0]],\n",
    "    [A[0], A[1]],\n",
    "    [A[0], A[2]],\n",
    "    [A[0], A[1], A[2], A[3]],\n",
    "    [A[4]],\n",
    "]\n",
    "for j in range(5):\n",
    "    # Here we use the j-th component of the result, so only it and the components it\n",
    "    # depends on are left.\n",
    "    simplified = ck.analyzer.canonical_simplify(\n",
    "        some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]\n",
    "    )\n",
    "\n",
    "    # Check that the remaining components are the expected ones.\n",
    "    for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):\n",
    "        assert tvm.ir.structural_equal(lhs, rhs)\n",
    "\n",
    "# Test that components with side effects are not removed\n",
    "dummy = tvm.ir.GlobalVar(\"dummy\")\n",
    "side_effect = lambda *xs: tvm.tir.Call(\"int32\", dummy, xs)\n",
    "ck.verify(\n",
    "    sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],\n",
    "    sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],\n",
    ")\n",
    "ck.verify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0], te.sum(side_effect(A[k]), k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "k = te.reduce_axis((0, 10), name=\"k\")\n",
    "j = te.reduce_axis((-5, 3), name=\"j\")\n",
    "A = te.placeholder((10,), name=\"A\")\n",
    "ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]), te.sum(k + j, [k, j]))\n",
    "ck.verify(te.sum(A[3], []), A[3])\n",
    "ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype=\"float32\"))\n",
    "# The rule below is not typical, removed for now\n",
    "ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, \"int32\"), k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` if_then_else"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x = te.var(\"x\")\n",
    "y = te.var(\"y\")\n",
    "tdiv = tvm.tir.truncdiv\n",
    "tmod = tvm.tir.truncmod\n",
    "# simplification that takes condition into account.\n",
    "res = tvm.tir.if_then_else(\n",
    "    (x * 4 + y) >= 466036,\n",
    "    tvm.tir.if_then_else(\n",
    "        24512 <= tmod(((x * 4) + y) - 466036, 24528),\n",
    "        tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),\n",
    "        x,\n",
    "    ),\n",
    "    y,\n",
    ")\n",
    "\n",
    "res2 = tvm.tir.if_then_else(\n",
    "    (x * 4) >= 466036 - y,\n",
    "    tvm.tir.if_then_else(\n",
    "        24512 <= tmod(((x * 4) + y) - 466036, 24528),\n",
    "        tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),\n",
    "        x,\n",
    "    ),\n",
    "    y,\n",
    ")\n",
    "expected = tvm.tir.if_then_else(\n",
    "    tvm.tir.LE(466036, (x * 4 + y)),\n",
    "    tvm.tir.if_then_else(\n",
    "        tvm.tir.LE(24512, tmod(((x * 4) + y) - 4, 24528)), tmod(((x * 4) + y) - 4, 16), x\n",
    "    ),\n",
    "    y,\n",
    ")\n",
    "ck.verify(res, expected)\n",
    "ck.verify(res2, expected)\n",
    "# can only simplify if condition\n",
    "res = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))\n",
    "expected = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))\n",
    "ck.verify(res, ck.analyzer.canonical_simplify(expected))\n",
    "\n",
    "res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) > 2, x, 0), 0)\n",
    "expected = tvm.tir.Select(x >= 10, x, 0)\n",
    "ck.verify(res, ck.analyzer.canonical_simplify(expected))\n",
    "\n",
    "res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) < 2, x, 0), 0)\n",
    "ck.verify(res, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` complex_cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x = te.var(\"x\")\n",
    "y = te.var(\"y\")\n",
    "tdiv = tvm.tir.truncdiv\n",
    "tmod = tvm.tir.truncmod\n",
    "res2 = (\n",
    "    tdiv(tdiv(tmod(x * 128 + y, 1296), 36) * 2 + 1, 2) * 36\n",
    "    + tdiv(tmod((x * 128) + y, 36) * 2 + 1, 2)\n",
    "    - tmod((x * 128) + y, 1296)\n",
    "    + 1\n",
    ")\n",
    "ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))\n",
    "ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))\n",
    "ck.verify(res2, 1)\n",
    "\n",
    "ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)\n",
    "res3 = (\n",
    "    tdiv(x * 1024 + y, 65536)\n",
    "    + tdiv(tmod(x * 1024 + y, 65536), 256)\n",
    "    + tdiv(tmod(x * 1024 + y, 256), 16)\n",
    "    + tmod(x * 1024 + y, 16)\n",
    "    - tdiv(y, 256)\n",
    "    - tdiv(tmod(y, 256), 16)\n",
    "    - tmod(y, 16)\n",
    "    - (x * 4)\n",
    ")\n",
    "ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` cast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "tcast = tvm.tir.Cast\n",
    "fld = tvm.te.floordiv\n",
    "flm = tvm.te.floormod\n",
    "# cast(i64, i + j + 1) - cast(i64, i)\n",
    "i = te.var(\"i\", dtype=\"int32\")\n",
    "j = te.var(\"j\", dtype=\"int32\")\n",
    "res = tcast(\"int64\", i + j + 1) - tcast(\"int64\", i)\n",
    "ck.verify(res, tcast(\"int64\", j) + tvm.tir.const(1, \"int64\"))\n",
    "# cast(i32, i + j + 1) - cast(i32, i)\n",
    "i = te.var(\"i\", dtype=\"int64\")\n",
    "j = te.var(\"j\", dtype=\"int64\")\n",
    "ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))\n",
    "ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))\n",
    "res = tcast(\"int32\", i + j + 1) - tcast(\"int32\", i)\n",
    "ck.verify(res, tcast(\"int32\", j) + 1)\n",
    "# cast(i32, i + j - 100)\n",
    "i = te.var(\"i\", dtype=\"int64\")\n",
    "j = te.var(\"j\", dtype=\"int64\")\n",
    "ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1))\n",
    "ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))\n",
    "res = tcast(\"int32\", i + j - 100)\n",
    "ck.verify(res, res)\n",
    "# cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32\n",
    "# - cast(i32, flm(axis, 7i64) * 2i64)\n",
    "axis = te.var(\"axis\", dtype=\"int64\")\n",
    "ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))\n",
    "res = (\n",
    "    tcast(\n",
    "        \"int32\",\n",
    "        flm(axis, tvm.tir.const(7, \"int64\")) * tvm.tir.const(2, \"int64\")\n",
    "        + tvm.tir.const(1, \"int64\"),\n",
    "    )\n",
    "    + tvm.tir.const(1, \"int32\")\n",
    "    - tcast(\"int32\", flm(axis, tvm.tir.const(7, \"int64\")) * tvm.tir.const(2, \"int64\"))\n",
    ")\n",
    "ck.verify(res, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` normalize_min_value_expr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "x = te.var(\"x\", \"int32\")\n",
    "\n",
    "ck.verify(te.min_value(\"int32\") - x == 0, x == te.min_value(\"int32\"))\n",
    "ck.verify(te.min_value(\"int32\") + x == 0, False)\n",
    "ck.verify(0 == te.min_value(\"int32\") - x, x == te.min_value(\"int32\"))\n",
    "ck.verify(0 == te.min_value(\"int32\") + x, False)\n",
    "ck.verify(-x + te.min_value(\"int32\") == 0, x == te.min_value(\"int32\"))\n",
    "ck.verify(x + te.min_value(\"int32\") == 0, False)\n",
    "ck.verify(0 == -x + te.min_value(\"int32\"), x == te.min_value(\"int32\"))\n",
    "ck.verify(0 == x + te.min_value(\"int32\"), False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` proddiv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "flm = tvm.te.floormod\n",
    "fld = tvm.te.floordiv\n",
    "tdiv = tvm.te.truncdiv\n",
    "tmod = tvm.te.truncmod\n",
    "\n",
    "x, y, z = te.var(\"x\"), te.var(\"y\"), te.var(\"y\")\n",
    "\n",
    "ck.verify(flm(x * 32 * x, x), 0)\n",
    "ck.verify(flm(z * x * 32 * x * y, x * z), 0)\n",
    "ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)\n",
    "ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))\n",
    "ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)\n",
    "\n",
    "ck.verify(tmod(x * 32 * x, x), 0)\n",
    "ck.verify(tmod(z * x * 32 * x * y, x * z), 0)\n",
    "ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))\n",
    "ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)\n",
    "\n",
    "ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))\n",
    "ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)\n",
    "ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))\n",
    "\n",
    "ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))\n",
    "ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)\n",
    "ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` floormod_two"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "flm = tvm.te.floormod\n",
    "x, y = te.var(\"x\"), te.var(\"y\")\n",
    "ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.canonical_simplify` le"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ck = CanonicalChecker()\n",
    "# Case 1. Ignore the extra expr if it's small than the division number\n",
    "x, y, z = te.var(\"x\"), te.var(\"y\"), te.var(\"z\")\n",
    "ck.analyzer.bind(y, tvm.ir.Range(0, 8))\n",
    "ck.analyzer.bind(z, tvm.ir.Range(0, 2))\n",
    "ck.verify(x * 8 + y < 16, x < 2)\n",
    "ck.verify(x * 8 + z * 4 < 16, x < 2)\n",
    "ck.verify(x * 8 + z * 4 < 16, x < 2)\n",
    "\n",
    "# TODO: Not sure why `-2 < x` will be convert to `x > -2`, use a explicit simplify here.\n",
    "ck.verify(x * -8 + y < 16, ck.analyzer.rewrite_simplify(-2 < x))\n",
    "ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))\n",
    "\n",
    "ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)\n",
    "ck.verify(x * 8 + y - z < 16, x < 2)\n",
    "\n",
    "n = te.size_var(\"n\")\n",
    "ck.verify(x * 8 + y < n, x * 8 + y < n)\n",
    "\n",
    "# Case 2. Simplify the extra expr\n",
    "x1, x2, ty, tx, vec = (\n",
    "    tvm.te.var(\"x1\"),\n",
    "    tvm.te.var(\"x2\"),\n",
    "    tvm.te.var(\"ty\"),\n",
    "    tvm.te.var(\"tx\"),\n",
    "    tvm.te.var(\"vec\"),\n",
    ")\n",
    "ck.analyzer.bind(x1, tvm.ir.Range(0, 2))\n",
    "ck.analyzer.bind(x2, tvm.ir.Range(0, 3))\n",
    "ck.analyzer.bind(ty, tvm.ir.Range(0, 8))\n",
    "ck.analyzer.bind(tx, tvm.ir.Range(0, 32))\n",
    "ck.analyzer.bind(vec, tvm.ir.Range(0, 8))\n",
    "ck.verify(\n",
    "    x1 * 5632 + (((x2 * 8 + ty) * 32 + tx) * 8 + vec) % 5632 < 11008,\n",
    "    x1 * 22 + (x2 * 8 + ty) % 22 < 43,\n",
    ")\n",
    "ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)\n",
    "\n",
    "# Case 3. No failure\n",
    "x, y, z = te.var(\"x\"), te.var(\"y\"), te.var(\"z\")\n",
    "ck.analyzer.bind(y, tvm.ir.Range(0, 1024))\n",
    "ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
